# -*- coding: utf-8 -*-  
'''
BERT pre-training entry script.

Created on 2021-09-08
@author: luoyi
'''
import utils.conf as conf
from utils.iexicon import WordsWarehouse
from data.sohu_thuc_news.bert_dataset import BertPreTFRecordDataset
from models.bert.nets import BertModel


#    Initialize the word warehouse (singleton vocabulary store) from its pickle.
#    NOTE(review): the original called WordsWarehouse().instance(), constructing a
#    throwaway object before fetching the singleton; line 37 of this file uses the
#    class-level accessor WordsWarehouse.instance(), so that form is normalized here.
#    The singleton is also fetched once and reused instead of twice.
words_warehouse = WordsWarehouse.instance()
words_warehouse.load_pkl()
print('初始化词库. ', words_warehouse.words_count())


#    Prepare the pre-training datasets.
#    steps_per_epoch is derived from the configured number of training samples
#    divided (floor) by the batch size.
count_train = conf.DATASET_SOHU_THUCNEWS.get_pre_training_train_count()
batch_size = conf.BERT.get_batch_size()
epochs = conf.BERT.get_epochs()
steps_per_epoch = count_train // batch_size

#    Train/validation pipelines from the TFRecord-backed dataset.
bert_pre_ds = BertPreTFRecordDataset()
db_train = bert_pre_ds.tensor_db_train(batch_size, epochs=epochs)
db_val = bert_pre_ds.tensor_db_val(batch_size, epochs=epochs)


#    Build the BERT model. All hyper-parameters come from the BERT section of
#    the project config; vocab_size comes from the word-warehouse singleton.
_bert_kwargs = dict(
    name='bert',
    learning_rate=conf.BERT.get_learning_rate(),
    # assumes each sample is a pair of length-max_sen_len id sequences
    # (hence the 2 in the shape) — TODO confirm against the dataset builder
    input_shape=(None, 2, conf.BERT.get_max_sen_len()),
    auto_assembling=True,
    is_build=True,

    vocab_size=WordsWarehouse.instance().words_count(),
    max_sen_len=conf.BERT.get_max_sen_len(),
    max_sen=3,
    n_block=conf.BERT.get_n_block(),
    n_head=conf.BERT.get_n_head_attention(),
    d_model=conf.BERT.get_d_model(),
    f_model=conf.BERT.get_f_model(),
    dropout_rate=conf.BERT.get_dropout_rate(),

    # Weights for the two pre-training losses (NSP and MLM).
    lamud_loss_pre_nsp=conf.BERT.get_lamud_loss_pre_nsp(),
    lamud_loss_pre_mlm=conf.BERT.get_lamud_loss_pre_mlm(),
)
bert = BertModel(**_bert_kwargs)


#    Feed the data: run training with automatic weight saving, learning-rate
#    scheduling, and TensorBoard logging to the configured paths.
#    NOTE(review): 'auto_save_weights_after_traind' (sic) matches the keyword
#    name declared by train_tensor_db elsewhere — do not "fix" the spelling here.
bert.train_tensor_db(db_train,
                     db_val,
                     steps_per_epoch,
                     batch_size,
                     epochs,
                     auto_save_weights_after_traind=True,
                     auto_save_weights_dir=conf.BERT.get_model_save_weights_path(),
                     auto_learning_rate_schedule=True,
                     auto_tensorboard=True,
                     auto_tensorboard_dir=conf.BERT.get_tensorboard_dir_path())
